{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# WikiQA splits loaded as ranking data packs. The dev and test loads pass\n",
    "# filter=True; the train load leaves filter at its default.\n",
    "train_pack = mz.datasets.wiki_qa.load_data(\n",
    "    'train',\n",
    "    task='ranking',\n",
    ")\n",
    "valid_pack = mz.datasets.wiki_qa.load_data(\n",
    "    'dev',\n",
    "    task='ranking',\n",
    "    filter=True,\n",
    ")\n",
    "predict_pack = mz.datasets.wiki_qa.load_data(\n",
    "    'test',\n",
    "    task='ranking',\n",
    "    filter=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8686.08it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4822.29it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 744204.86it/s]\n",
      "Building FrequencyFilterUnit from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 131779.52it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 70505.55it/s] \n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 587341.21it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 662927.05it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 404415/404415 [00:00<00:00, 2777924.27it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2118/2118 [00:00<00:00, 8979.64it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 18841/18841 [00:03<00:00, 4812.69it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 121059.54it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 161389.72it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 121666.21it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 547960.52it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 704196.06it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 115637.91it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 89687.26it/s]\n"
     ]
    }
   ],
   "source": [
    "# Build the preprocessor, then fit it on the training pack and transform\n",
    "# that same pack in one call.\n",
    "preprocessor = mz.preprocessors.BasicPreprocessor(\n",
    "    fixed_length_left=10,\n",
    "    fixed_length_right=40,\n",
    "    remove_stop_words=False,\n",
    ")\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 122/122 [00:00<00:00, 8269.58it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 1115/1115 [00:00<00:00, 4815.67it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 123590.09it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 95094.79it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 123287.08it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 160258.41it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 529152.41it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 82162.02it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 75361.75it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 237/237 [00:00<00:00, 8641.43it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit: 100%|██████████| 2300/2300 [00:00<00:00, 4839.20it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 132822.51it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 149796.57it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 134599.76it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 246601.35it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 615510.70it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 94008.89it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 87378.96it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the already-fitted preprocessor over the dev split, then the test split.\n",
    "valid_pack_processed, predict_pack_processed = (\n",
    "    preprocessor.transform(split_pack)\n",
    "    for split_pack in (valid_pack, predict_pack)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking task trained with a rank hinge loss.\n",
    "ranking_task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss())\n",
    "\n",
    "# Metrics attached to the task: NDCG at cutoffs 3 and 5, followed by MAP.\n",
    "ranking_metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff)\n",
    "    for cutoff in (3, 5)\n",
    "]\n",
    "ranking_metrics.append(mz.metrics.MeanAveragePrecision())\n",
    "ranking_task.metrics = ranking_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to ConvKNRM.\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.ConvKNRM()\n",
    "\n",
    "# Explicit ConvKNRM hyperparameters, applied in the order listed. Anything not\n",
    "# listed here is left for guess_and_fill_missing_params() below.\n",
    "conv_knrm_settings = (\n",
    "    ('input_shapes', preprocessor.context['input_shapes']),\n",
    "    ('task', ranking_task),\n",
    "    ('embedding_input_dim', preprocessor.context['vocab_size']),\n",
    "    # 300 is also the dimension requested from load_glove_embedding later on.\n",
    "    ('embedding_output_dim', 300),\n",
    "    ('embedding_trainable', True),\n",
    "    ('filters', 128),\n",
    "    ('conv_activation_func', 'tanh'),\n",
    "    ('max_ngram', 3),\n",
    "    ('use_crossmatch', True),\n",
    "    ('kernel_num', 11),\n",
    "    ('sigma', 0.1),\n",
    "    ('exact_sigma', 0.001),\n",
    "    ('optimizer', 'adadelta'),\n",
    ")\n",
    "for param_name, param_value in conv_knrm_settings:\n",
    "    model.params[param_name] = param_value\n",
    "\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "# Optional inspection step, left disabled:\n",
    "# model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load 300-dimensional GloVe vectors and lay them out as a matrix using the\n",
    "# term index held by the fitted preprocessor's vocabulary unit.\n",
    "glove_embedding = mz.datasets.embeddings.load_glove_embedding(dimension=300)\n",
    "term_index = preprocessor.context['vocab_unit'].state['term_index']\n",
    "embedding_matrix = glove_embedding.build_matrix(term_index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the GloVe-derived matrix to the model.\n",
    "# NOTE(review): presumably this initialises the embedding layer weights -- confirm in the MatchZoo docs.\n",
    "model.load_embedding_matrix(embedding_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch monitoring data. The dev split is used so that the test split\n",
    "# (predict_pack_processed) is not consulted while training is in progress;\n",
    "# previously the callback was fed the test split and the dev split went unused.\n",
    "val_x, val_y = valid_pack_processed[:].unpack()\n",
    "\n",
    "# Test split, still unpacked under its original names so it stays available\n",
    "# for a final evaluation after training.\n",
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "\n",
    "# Callback that reports the task metrics on the dev split; batch_size equals\n",
    "# the number of dev examples, so the whole split goes through in one batch.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=val_x, y=val_y, batch_size=len(val_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "255"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pair-wise training batches built from the processed training pack. The\n",
    "# trailing expression displays how many batches make up one epoch.\n",
    "train_generator = mz.PairDataGenerator(\n",
    "    train_pack_processed,\n",
    "    num_dup=5,\n",
    "    num_neg=1,\n",
    "    batch_size=20,\n",
    ")\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/30\n",
      "255/255 [==============================] - 33s 129ms/step - loss: 0.4178\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5644249595231345 - normalized_discounted_cumulative_gain@5(0): 0.6181641523866899 - mean_average_precision(0): 0.5820818624479569\n",
      "Epoch 2/30\n",
      "255/255 [==============================] - 27s 106ms/step - loss: 0.0799\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5645181071040578 - normalized_discounted_cumulative_gain@5(0): 0.6283527787570147 - mean_average_precision(0): 0.5800770146746018\n",
      "Epoch 3/30\n",
      "255/255 [==============================] - 27s 104ms/step - loss: 0.0372\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.594064247710109 - normalized_discounted_cumulative_gain@5(0): 0.6443493868479856 - mean_average_precision(0): 0.5967726003485497\n",
      "Epoch 4/30\n",
      "255/255 [==============================] - 26s 104ms/step - loss: 0.0215\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5804273606919105 - normalized_discounted_cumulative_gain@5(0): 0.6440750987882721 - mean_average_precision(0): 0.5988759656640628\n",
      "Epoch 5/30\n",
      "255/255 [==============================] - ETA: 0s - loss: 0.016 - 25s 100ms/step - loss: 0.0165\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5726307929460168 - normalized_discounted_cumulative_gain@5(0): 0.6193040940724813 - mean_average_precision(0): 0.5668319701175248\n",
      "Epoch 6/30\n",
      "255/255 [==============================] - 26s 100ms/step - loss: 0.0114\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5840929435061598 - normalized_discounted_cumulative_gain@5(0): 0.6423600770752087 - mean_average_precision(0): 0.5958907514354033\n",
      "Epoch 7/30\n",
      "255/255 [==============================] - 26s 103ms/step - loss: 0.0062\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5855763878155281 - normalized_discounted_cumulative_gain@5(0): 0.6420936861470886 - mean_average_precision(0): 0.5952894390226855\n",
      "Epoch 8/30\n",
      "255/255 [==============================] - 25s 100ms/step - loss: 0.0028\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5973980323250132 - normalized_discounted_cumulative_gain@5(0): 0.6452809363595082 - mean_average_precision(0): 0.6027186992842243\n",
      "Epoch 9/30\n",
      "255/255 [==============================] - 26s 102ms/step - loss: 0.0016\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6045888272990727 - normalized_discounted_cumulative_gain@5(0): 0.6480621626853258 - mean_average_precision(0): 0.6079196547963747\n",
      "Epoch 10/30\n",
      "255/255 [==============================] - 26s 100ms/step - loss: 0.0013\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5960954150072109 - normalized_discounted_cumulative_gain@5(0): 0.6345872185641286 - mean_average_precision(0): 0.6026882672658678\n",
      "Epoch 11/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 0.0018\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5906357728003356 - normalized_discounted_cumulative_gain@5(0): 0.6420899565939712 - mean_average_precision(0): 0.5959738010558296\n",
      "Epoch 12/30\n",
      "254/255 [============================>.] - ETA: 0s - loss: 4.6736e-04Validation: normalized_discounted_cumulative_gain@3(0): 0.5906357728003356 - normalized_discounted_cumulative_gain@5(0): 0.6420899565939712 - mean_average_precision(0): 0.5959738010558296\n",
      "Epoch 12/30\n",
      "255/255 [==============================] - 26s 102ms/step - loss: 4.6553e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5792108452282783 - normalized_discounted_cumulative_gain@5(0): 0.6474050081303896 - mean_average_precision(0): 0.6022366688529198\n",
      "Epoch 13/30\n",
      "255/255 [==============================] - 26s 101ms/step - loss: 3.6288e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5976014226835163 - normalized_discounted_cumulative_gain@5(0): 0.650869011104818 - mean_average_precision(0): 0.6121357537523721\n",
      "Epoch 14/30\n",
      "255/255 [==============================] - 26s 102ms/step - loss: 5.6768e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6100119914582753 - normalized_discounted_cumulative_gain@5(0): 0.6620673553891757 - mean_average_precision(0): 0.6199898556304974\n",
      "Epoch 15/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 7.6058e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5954680993505707 - normalized_discounted_cumulative_gain@5(0): 0.6501858660448717 - mean_average_precision(0): 0.6093861578286818\n",
      "Epoch 16/30\n",
      "255/255 [==============================] - 25s 97ms/step - loss: 4.6683e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5995135649053625 - normalized_discounted_cumulative_gain@5(0): 0.6555209793595573 - mean_average_precision(0): 0.6103261879917136\n",
      "Epoch 17/30\n",
      "255/255 [==============================] - 25s 97ms/step - loss: 4.9290e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6074351481286072 - normalized_discounted_cumulative_gain@5(0): 0.6578611598638955 - mean_average_precision(0): 0.615852195289252\n",
      "Epoch 18/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 7.0720e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.601194670403182 - normalized_discounted_cumulative_gain@5(0): 0.6633953120087637 - mean_average_precision(0): 0.6152177876279701\n",
      "Epoch 19/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 3.8354e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6058265752632658 - normalized_discounted_cumulative_gain@5(0): 0.6581053207886893 - mean_average_precision(0): 0.6216270443932395\n",
      "Epoch 20/30\n",
      "255/255 [==============================] - 25s 99ms/step - loss: 1.5787e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5996976217973234 - normalized_discounted_cumulative_gain@5(0): 0.6479215804564739 - mean_average_precision(0): 0.6125724210074839\n",
      "Epoch 21/30\n",
      "255/255 [==============================] - 26s 102ms/step - loss: 4.0510e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6079716038586833 - normalized_discounted_cumulative_gain@5(0): 0.6529132296649792 - mean_average_precision(0): 0.6126753899557325\n",
      "Epoch 22/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 4.1566e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6100971365244625 - normalized_discounted_cumulative_gain@5(0): 0.6597477594228381 - mean_average_precision(0): 0.6158498393091412\n",
      "Epoch 23/30\n",
      "255/255 [==============================] - 26s 103ms/step - loss: 3.8046e-04ion: normalized_discounted_cumulative_gain@3(0): 0.6100971365244625 - normalized_discounted_cumulative_gain@5(0): 0.6597477594228381 - mean_average_precision(0): 0.61584983930914\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6194486401972197 - normalized_discounted_cumulative_gain@5(0): 0.6730401801927124 - mean_average_precision(0): 0.6289384905047652\n",
      "Epoch 24/30\n",
      "255/255 [==============================] - 26s 100ms/step - loss: 0.0011\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5961296836366201 - normalized_discounted_cumulative_gain@5(0): 0.6517440928706311 - mean_average_precision(0): 0.6091824689917628\n",
      "Epoch 25/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 0.0013\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6218643991550562 - normalized_discounted_cumulative_gain@5(0): 0.675140057306235 - mean_average_precision(0): 0.6316967457695356\n",
      "Epoch 26/30\n",
      "255/255 [==============================] - 26s 102ms/step - loss: 2.8035e-04\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6093550173234711 - normalized_discounted_cumulative_gain@5(0): 0.6668778990186204 - mean_average_precision(0): 0.62523186161927\n",
      "Epoch 27/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 5.3680e-05\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6209981079538025 - normalized_discounted_cumulative_gain@5(0): 0.6683650496544457 - mean_average_precision(0): 0.6320168909359449\n",
      "Epoch 28/30\n",
      "255/255 [==============================] - 25s 98ms/step - loss: 0.0000e+00\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6209981079538025 - normalized_discounted_cumulative_gain@5(0): 0.6683650496544457 - mean_average_precision(0): 0.6320168909359449\n",
      "Epoch 29/30\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "255/255 [==============================] - 25s 98ms/step - loss: 0.0000e+00\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6209981079538025 - normalized_discounted_cumulative_gain@5(0): 0.6683650496544457 - mean_average_precision(0): 0.6320168909359449\n",
      "Epoch 30/30\n",
      "255/255 [==============================] - 25s 96ms/step - loss: 0.0000e+00\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.6209981079538025 - normalized_discounted_cumulative_gain@5(0): 0.6683650496544457 - mean_average_precision(0): 0.6320168909359449\n"
     ]
    }
   ],
   "source": [
    "# Train for 30 epochs with the `evaluate` callback attached. Batch loading is\n",
    "# configured with workers=30 and use_multiprocessing=True.\n",
    "fit_options = dict(\n",
    "    epochs=30,\n",
    "    callbacks=[evaluate],\n",
    "    workers=30,\n",
    "    use_multiprocessing=True,\n",
    ")\n",
    "history = model.fit_generator(train_generator, **fit_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
